{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from comet_ml import Experiment\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "import models\n",
    "import argparse\n",
    "from helper import *\n",
    "torch.cuda.set_device(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "torch.cuda.manual_seed(0)\n",
    "\n",
    "classes = ['sky', 'building', 'pole', 'road', 'pavement', 'tree', 'signsymbol', 'fence', 'car', 'pedestrian', 'bicyclist', 'unlabelled']\n",
    "DATA_DIR = '../data/CamVid/'\n",
    "x_train_dir = os.path.join(DATA_DIR, 'train')\n",
    "y_train_dir = os.path.join(DATA_DIR, 'trainannot')\n",
    "\n",
    "x_valid_dir = os.path.join(DATA_DIR, 'val')\n",
    "y_valid_dir = os.path.join(DATA_DIR, 'valannot')\n",
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean = [0.41189489566336, 0.4251328133025, 0.4326707089857], std = [0.27413549931506, 0.28506257482912, 0.28284674400252])\n",
    "])\n",
    "\n",
    "train_dataset = CamVid(\n",
    "    x_train_dir,\n",
    "    y_train_dir,\n",
    "    classes = classes,\n",
    "    transform = transform\n",
    ")\n",
    "\n",
    "valid_dataset = CamVid(\n",
    "    x_valid_dir,\n",
    "    y_valid_dir,\n",
    "    classes = classes, \n",
    "    transform = transform\n",
    ")\n",
    "\n",
    "trainloader = DataLoader(train_dataset, batch_size = 8, shuffle = True, drop_last = True)\n",
    "valloader = DataLoader(valid_dataset, batch_size = 1, shuffle = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "student = models.unet.Unet('resnet26', classes = 12, encoder_weights = None).cuda()\n",
    "sf = SaveFeatures(student.encoder.relu)\n",
    "teacher = models.unet.Unet('resnet34', classes = 12, encoder_weights = None).cuda()\n",
    "teacher.load_state_dict(torch.load('../saved_models/resnet34/pretrained_0.pt'))\n",
    "sf2 = SaveFeatures(teacher.encoder.relu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder.conv1.weight torch.Size([64, 3, 7, 7]) True\n",
      "encoder.bn1.weight torch.Size([64]) True\n",
      "encoder.bn1.bias torch.Size([64]) True\n",
      "encoder.layer1.0.conv1.weight torch.Size([64, 64, 3, 3]) True\n",
      "encoder.layer1.0.bn1.weight torch.Size([64]) True\n",
      "encoder.layer1.0.bn1.bias torch.Size([64]) True\n",
      "encoder.layer1.0.conv2.weight torch.Size([64, 64, 3, 3]) True\n",
      "encoder.layer1.0.bn2.weight torch.Size([64]) True\n",
      "encoder.layer1.0.bn2.bias torch.Size([64]) True\n",
      "encoder.layer1.1.conv1.weight torch.Size([64, 64, 3, 3]) True\n",
      "encoder.layer1.1.bn1.weight torch.Size([64]) True\n",
      "encoder.layer1.1.bn1.bias torch.Size([64]) True\n",
      "encoder.layer1.1.conv2.weight torch.Size([64, 64, 3, 3]) True\n",
      "encoder.layer1.1.bn2.weight torch.Size([64]) True\n",
      "encoder.layer1.1.bn2.bias torch.Size([64]) True\n",
      "encoder.layer1.2.conv1.weight torch.Size([64, 64, 3, 3]) True\n",
      "encoder.layer1.2.bn1.weight torch.Size([64]) True\n",
      "encoder.layer1.2.bn1.bias torch.Size([64]) True\n",
      "encoder.layer1.2.conv2.weight torch.Size([64, 64, 3, 3]) True\n",
      "encoder.layer1.2.bn2.weight torch.Size([64]) True\n",
      "encoder.layer1.2.bn2.bias torch.Size([64]) True\n",
      "encoder.layer2.0.conv1.weight torch.Size([128, 64, 3, 3]) True\n",
      "encoder.layer2.0.bn1.weight torch.Size([128]) True\n",
      "encoder.layer2.0.bn1.bias torch.Size([128]) True\n",
      "encoder.layer2.0.conv2.weight torch.Size([128, 128, 3, 3]) True\n",
      "encoder.layer2.0.bn2.weight torch.Size([128]) True\n",
      "encoder.layer2.0.bn2.bias torch.Size([128]) True\n",
      "encoder.layer2.0.downsample.0.weight torch.Size([128, 64, 1, 1]) True\n",
      "encoder.layer2.0.downsample.1.weight torch.Size([128]) True\n",
      "encoder.layer2.0.downsample.1.bias torch.Size([128]) True\n",
      "encoder.layer2.1.conv1.weight torch.Size([128, 128, 3, 3]) True\n",
      "encoder.layer2.1.bn1.weight torch.Size([128]) True\n",
      "encoder.layer2.1.bn1.bias torch.Size([128]) True\n",
      "encoder.layer2.1.conv2.weight torch.Size([128, 128, 3, 3]) True\n",
      "encoder.layer2.1.bn2.weight torch.Size([128]) True\n",
      "encoder.layer2.1.bn2.bias torch.Size([128]) True\n",
      "encoder.layer2.2.conv1.weight torch.Size([128, 128, 3, 3]) True\n",
      "encoder.layer2.2.bn1.weight torch.Size([128]) True\n",
      "encoder.layer2.2.bn1.bias torch.Size([128]) True\n",
      "encoder.layer2.2.conv2.weight torch.Size([128, 128, 3, 3]) True\n",
      "encoder.layer2.2.bn2.weight torch.Size([128]) True\n",
      "encoder.layer2.2.bn2.bias torch.Size([128]) True\n",
      "encoder.layer3.0.conv1.weight torch.Size([256, 128, 3, 3]) True\n",
      "encoder.layer3.0.bn1.weight torch.Size([256]) True\n",
      "encoder.layer3.0.bn1.bias torch.Size([256]) True\n",
      "encoder.layer3.0.conv2.weight torch.Size([256, 256, 3, 3]) True\n",
      "encoder.layer3.0.bn2.weight torch.Size([256]) True\n",
      "encoder.layer3.0.bn2.bias torch.Size([256]) True\n",
      "encoder.layer3.0.downsample.0.weight torch.Size([256, 128, 1, 1]) True\n",
      "encoder.layer3.0.downsample.1.weight torch.Size([256]) True\n",
      "encoder.layer3.0.downsample.1.bias torch.Size([256]) True\n",
      "encoder.layer3.1.conv1.weight torch.Size([256, 256, 3, 3]) True\n",
      "encoder.layer3.1.bn1.weight torch.Size([256]) True\n",
      "encoder.layer3.1.bn1.bias torch.Size([256]) True\n",
      "encoder.layer3.1.conv2.weight torch.Size([256, 256, 3, 3]) True\n",
      "encoder.layer3.1.bn2.weight torch.Size([256]) True\n",
      "encoder.layer3.1.bn2.bias torch.Size([256]) True\n",
      "encoder.layer3.2.conv1.weight torch.Size([256, 256, 3, 3]) True\n",
      "encoder.layer3.2.bn1.weight torch.Size([256]) True\n",
      "encoder.layer3.2.bn1.bias torch.Size([256]) True\n",
      "encoder.layer3.2.conv2.weight torch.Size([256, 256, 3, 3]) True\n",
      "encoder.layer3.2.bn2.weight torch.Size([256]) True\n",
      "encoder.layer3.2.bn2.bias torch.Size([256]) True\n",
      "encoder.layer4.0.conv1.weight torch.Size([512, 256, 3, 3]) True\n",
      "encoder.layer4.0.bn1.weight torch.Size([512]) True\n",
      "encoder.layer4.0.bn1.bias torch.Size([512]) True\n",
      "encoder.layer4.0.conv2.weight torch.Size([512, 512, 3, 3]) True\n",
      "encoder.layer4.0.bn2.weight torch.Size([512]) True\n",
      "encoder.layer4.0.bn2.bias torch.Size([512]) True\n",
      "encoder.layer4.0.downsample.0.weight torch.Size([512, 256, 1, 1]) True\n",
      "encoder.layer4.0.downsample.1.weight torch.Size([512]) True\n",
      "encoder.layer4.0.downsample.1.bias torch.Size([512]) True\n",
      "encoder.layer4.1.conv1.weight torch.Size([512, 512, 3, 3]) True\n",
      "encoder.layer4.1.bn1.weight torch.Size([512]) True\n",
      "encoder.layer4.1.bn1.bias torch.Size([512]) True\n",
      "encoder.layer4.1.conv2.weight torch.Size([512, 512, 3, 3]) True\n",
      "encoder.layer4.1.bn2.weight torch.Size([512]) True\n",
      "encoder.layer4.1.bn2.bias torch.Size([512]) True\n",
      "encoder.layer4.2.conv1.weight torch.Size([512, 512, 3, 3]) True\n",
      "encoder.layer4.2.bn1.weight torch.Size([512]) True\n",
      "encoder.layer4.2.bn1.bias torch.Size([512]) True\n",
      "encoder.layer4.2.conv2.weight torch.Size([512, 512, 3, 3]) True\n",
      "encoder.layer4.2.bn2.weight torch.Size([512]) True\n",
      "encoder.layer4.2.bn2.bias torch.Size([512]) True\n",
      "decoder.blocks.0.conv1.0.weight torch.Size([256, 768, 3, 3]) True\n",
      "decoder.blocks.0.conv1.1.weight torch.Size([256]) True\n",
      "decoder.blocks.0.conv1.1.bias torch.Size([256]) True\n",
      "decoder.blocks.0.conv2.0.weight torch.Size([256, 256, 3, 3]) True\n",
      "decoder.blocks.0.conv2.1.weight torch.Size([256]) True\n",
      "decoder.blocks.0.conv2.1.bias torch.Size([256]) True\n",
      "decoder.blocks.1.conv1.0.weight torch.Size([128, 384, 3, 3]) True\n",
      "decoder.blocks.1.conv1.1.weight torch.Size([128]) True\n",
      "decoder.blocks.1.conv1.1.bias torch.Size([128]) True\n",
      "decoder.blocks.1.conv2.0.weight torch.Size([128, 128, 3, 3]) True\n",
      "decoder.blocks.1.conv2.1.weight torch.Size([128]) True\n",
      "decoder.blocks.1.conv2.1.bias torch.Size([128]) True\n",
      "decoder.blocks.2.conv1.0.weight torch.Size([64, 192, 3, 3]) True\n",
      "decoder.blocks.2.conv1.1.weight torch.Size([64]) True\n",
      "decoder.blocks.2.conv1.1.bias torch.Size([64]) True\n",
      "decoder.blocks.2.conv2.0.weight torch.Size([64, 64, 3, 3]) True\n",
      "decoder.blocks.2.conv2.1.weight torch.Size([64]) True\n",
      "decoder.blocks.2.conv2.1.bias torch.Size([64]) True\n",
      "decoder.blocks.3.conv1.0.weight torch.Size([32, 128, 3, 3]) True\n",
      "decoder.blocks.3.conv1.1.weight torch.Size([32]) True\n",
      "decoder.blocks.3.conv1.1.bias torch.Size([32]) True\n",
      "decoder.blocks.3.conv2.0.weight torch.Size([32, 32, 3, 3]) True\n",
      "decoder.blocks.3.conv2.1.weight torch.Size([32]) True\n",
      "decoder.blocks.3.conv2.1.bias torch.Size([32]) True\n",
      "decoder.blocks.4.conv1.0.weight torch.Size([16, 32, 3, 3]) True\n",
      "decoder.blocks.4.conv1.1.weight torch.Size([16]) True\n",
      "decoder.blocks.4.conv1.1.bias torch.Size([16]) True\n",
      "decoder.blocks.4.conv2.0.weight torch.Size([16, 16, 3, 3]) True\n",
      "decoder.blocks.4.conv2.1.weight torch.Size([16]) True\n",
      "decoder.blocks.4.conv2.1.bias torch.Size([16]) True\n",
      "segmentation_head.0.weight torch.Size([12, 16, 3, 3]) True\n",
      "segmentation_head.0.bias torch.Size([12]) True\n"
     ]
    }
   ],
   "source": [
    "for name, param in student.named_parameters() : \n",
    "    print(name, param.shape, param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder.conv1.weight torch.Size([64, 3, 7, 7]) True\n",
      "encoder.bn1.weight torch.Size([64]) True\n",
      "encoder.bn1.bias torch.Size([64]) True\n",
      "encoder.layer1.0.conv1.weight torch.Size([64, 64, 3, 3]) False\n",
      "encoder.layer1.0.bn1.weight torch.Size([64]) False\n",
      "encoder.layer1.0.bn1.bias torch.Size([64]) False\n",
      "encoder.layer1.0.conv2.weight torch.Size([64, 64, 3, 3]) False\n",
      "encoder.layer1.0.bn2.weight torch.Size([64]) False\n",
      "encoder.layer1.0.bn2.bias torch.Size([64]) False\n",
      "encoder.layer1.1.conv1.weight torch.Size([64, 64, 3, 3]) False\n",
      "encoder.layer1.1.bn1.weight torch.Size([64]) False\n",
      "encoder.layer1.1.bn1.bias torch.Size([64]) False\n",
      "encoder.layer1.1.conv2.weight torch.Size([64, 64, 3, 3]) False\n",
      "encoder.layer1.1.bn2.weight torch.Size([64]) False\n",
      "encoder.layer1.1.bn2.bias torch.Size([64]) False\n",
      "encoder.layer1.2.conv1.weight torch.Size([64, 64, 3, 3]) False\n",
      "encoder.layer1.2.bn1.weight torch.Size([64]) False\n",
      "encoder.layer1.2.bn1.bias torch.Size([64]) False\n",
      "encoder.layer1.2.conv2.weight torch.Size([64, 64, 3, 3]) False\n",
      "encoder.layer1.2.bn2.weight torch.Size([64]) False\n",
      "encoder.layer1.2.bn2.bias torch.Size([64]) False\n",
      "encoder.layer2.0.conv1.weight torch.Size([128, 64, 3, 3]) False\n",
      "encoder.layer2.0.bn1.weight torch.Size([128]) False\n",
      "encoder.layer2.0.bn1.bias torch.Size([128]) False\n",
      "encoder.layer2.0.conv2.weight torch.Size([128, 128, 3, 3]) False\n",
      "encoder.layer2.0.bn2.weight torch.Size([128]) False\n",
      "encoder.layer2.0.bn2.bias torch.Size([128]) False\n",
      "encoder.layer2.0.downsample.0.weight torch.Size([128, 64, 1, 1]) False\n",
      "encoder.layer2.0.downsample.1.weight torch.Size([128]) False\n",
      "encoder.layer2.0.downsample.1.bias torch.Size([128]) False\n",
      "encoder.layer2.1.conv1.weight torch.Size([128, 128, 3, 3]) False\n",
      "encoder.layer2.1.bn1.weight torch.Size([128]) False\n",
      "encoder.layer2.1.bn1.bias torch.Size([128]) False\n",
      "encoder.layer2.1.conv2.weight torch.Size([128, 128, 3, 3]) False\n",
      "encoder.layer2.1.bn2.weight torch.Size([128]) False\n",
      "encoder.layer2.1.bn2.bias torch.Size([128]) False\n",
      "encoder.layer2.2.conv1.weight torch.Size([128, 128, 3, 3]) False\n",
      "encoder.layer2.2.bn1.weight torch.Size([128]) False\n",
      "encoder.layer2.2.bn1.bias torch.Size([128]) False\n",
      "encoder.layer2.2.conv2.weight torch.Size([128, 128, 3, 3]) False\n",
      "encoder.layer2.2.bn2.weight torch.Size([128]) False\n",
      "encoder.layer2.2.bn2.bias torch.Size([128]) False\n",
      "encoder.layer3.0.conv1.weight torch.Size([256, 128, 3, 3]) False\n",
      "encoder.layer3.0.bn1.weight torch.Size([256]) False\n",
      "encoder.layer3.0.bn1.bias torch.Size([256]) False\n",
      "encoder.layer3.0.conv2.weight torch.Size([256, 256, 3, 3]) False\n",
      "encoder.layer3.0.bn2.weight torch.Size([256]) False\n",
      "encoder.layer3.0.bn2.bias torch.Size([256]) False\n",
      "encoder.layer3.0.downsample.0.weight torch.Size([256, 128, 1, 1]) False\n",
      "encoder.layer3.0.downsample.1.weight torch.Size([256]) False\n",
      "encoder.layer3.0.downsample.1.bias torch.Size([256]) False\n",
      "encoder.layer3.1.conv1.weight torch.Size([256, 256, 3, 3]) False\n",
      "encoder.layer3.1.bn1.weight torch.Size([256]) False\n",
      "encoder.layer3.1.bn1.bias torch.Size([256]) False\n",
      "encoder.layer3.1.conv2.weight torch.Size([256, 256, 3, 3]) False\n",
      "encoder.layer3.1.bn2.weight torch.Size([256]) False\n",
      "encoder.layer3.1.bn2.bias torch.Size([256]) False\n",
      "encoder.layer3.2.conv1.weight torch.Size([256, 256, 3, 3]) False\n",
      "encoder.layer3.2.bn1.weight torch.Size([256]) False\n",
      "encoder.layer3.2.bn1.bias torch.Size([256]) False\n",
      "encoder.layer3.2.conv2.weight torch.Size([256, 256, 3, 3]) False\n",
      "encoder.layer3.2.bn2.weight torch.Size([256]) False\n",
      "encoder.layer3.2.bn2.bias torch.Size([256]) False\n",
      "encoder.layer4.0.conv1.weight torch.Size([512, 256, 3, 3]) False\n",
      "encoder.layer4.0.bn1.weight torch.Size([512]) False\n",
      "encoder.layer4.0.bn1.bias torch.Size([512]) False\n",
      "encoder.layer4.0.conv2.weight torch.Size([512, 512, 3, 3]) False\n",
      "encoder.layer4.0.bn2.weight torch.Size([512]) False\n",
      "encoder.layer4.0.bn2.bias torch.Size([512]) False\n",
      "encoder.layer4.0.downsample.0.weight torch.Size([512, 256, 1, 1]) False\n",
      "encoder.layer4.0.downsample.1.weight torch.Size([512]) False\n",
      "encoder.layer4.0.downsample.1.bias torch.Size([512]) False\n",
      "encoder.layer4.1.conv1.weight torch.Size([512, 512, 3, 3]) False\n",
      "encoder.layer4.1.bn1.weight torch.Size([512]) False\n",
      "encoder.layer4.1.bn1.bias torch.Size([512]) False\n",
      "encoder.layer4.1.conv2.weight torch.Size([512, 512, 3, 3]) False\n",
      "encoder.layer4.1.bn2.weight torch.Size([512]) False\n",
      "encoder.layer4.1.bn2.bias torch.Size([512]) False\n",
      "encoder.layer4.2.conv1.weight torch.Size([512, 512, 3, 3]) False\n",
      "encoder.layer4.2.bn1.weight torch.Size([512]) False\n",
      "encoder.layer4.2.bn1.bias torch.Size([512]) False\n",
      "encoder.layer4.2.conv2.weight torch.Size([512, 512, 3, 3]) False\n",
      "encoder.layer4.2.bn2.weight torch.Size([512]) False\n",
      "encoder.layer4.2.bn2.bias torch.Size([512]) False\n",
      "decoder.blocks.0.conv1.0.weight torch.Size([256, 768, 3, 3]) False\n",
      "decoder.blocks.0.conv1.1.weight torch.Size([256]) False\n",
      "decoder.blocks.0.conv1.1.bias torch.Size([256]) False\n",
      "decoder.blocks.0.conv2.0.weight torch.Size([256, 256, 3, 3]) False\n",
      "decoder.blocks.0.conv2.1.weight torch.Size([256]) False\n",
      "decoder.blocks.0.conv2.1.bias torch.Size([256]) False\n",
      "decoder.blocks.1.conv1.0.weight torch.Size([128, 384, 3, 3]) False\n",
      "decoder.blocks.1.conv1.1.weight torch.Size([128]) False\n",
      "decoder.blocks.1.conv1.1.bias torch.Size([128]) False\n",
      "decoder.blocks.1.conv2.0.weight torch.Size([128, 128, 3, 3]) False\n",
      "decoder.blocks.1.conv2.1.weight torch.Size([128]) False\n",
      "decoder.blocks.1.conv2.1.bias torch.Size([128]) False\n",
      "decoder.blocks.2.conv1.0.weight torch.Size([64, 192, 3, 3]) False\n",
      "decoder.blocks.2.conv1.1.weight torch.Size([64]) False\n",
      "decoder.blocks.2.conv1.1.bias torch.Size([64]) False\n",
      "decoder.blocks.2.conv2.0.weight torch.Size([64, 64, 3, 3]) False\n",
      "decoder.blocks.2.conv2.1.weight torch.Size([64]) False\n",
      "decoder.blocks.2.conv2.1.bias torch.Size([64]) False\n",
      "decoder.blocks.3.conv1.0.weight torch.Size([32, 128, 3, 3]) False\n",
      "decoder.blocks.3.conv1.1.weight torch.Size([32]) False\n",
      "decoder.blocks.3.conv1.1.bias torch.Size([32]) False\n",
      "decoder.blocks.3.conv2.0.weight torch.Size([32, 32, 3, 3]) False\n",
      "decoder.blocks.3.conv2.1.weight torch.Size([32]) False\n",
      "decoder.blocks.3.conv2.1.bias torch.Size([32]) False\n",
      "decoder.blocks.4.conv1.0.weight torch.Size([16, 32, 3, 3]) False\n",
      "decoder.blocks.4.conv1.1.weight torch.Size([16]) False\n",
      "decoder.blocks.4.conv1.1.bias torch.Size([16]) False\n",
      "decoder.blocks.4.conv2.0.weight torch.Size([16, 16, 3, 3]) False\n",
      "decoder.blocks.4.conv2.1.weight torch.Size([16]) False\n",
      "decoder.blocks.4.conv2.1.bias torch.Size([16]) False\n",
      "segmentation_head.0.weight torch.Size([12, 16, 3, 3]) False\n",
      "segmentation_head.0.bias torch.Size([12]) False\n"
     ]
    }
   ],
   "source": [
    "def unfreeze(model, stage) : \n",
    "    if stage == 0 :\n",
    "        for name, param in model.named_parameters() :\n",
    "            param.requires_grad = False\n",
    "            if name.startswith('encoder.conv') or name.startswith('encoder.bn') : \n",
    "                param.requires_grad = True\n",
    "    \n",
    "    elif stage > 0 and stage < 5 : \n",
    "        for name, param in model.named_parameters() :\n",
    "            param.requires_grad = False\n",
    "            if name.startswith('encoder.layer' + str(stage)) : \n",
    "                param.requires_grad = True\n",
    "    \n",
    "    elif stage > 4 and stage < 10 : \n",
    "        for name, param in model.named_parameters() :\n",
    "            param.requires_grad = False\n",
    "            if name.startswith('decoder.blocks.' + str(stage - 5)) : \n",
    "                param.requires_grad = True\n",
    "    \n",
    "    elif stage == 10 : \n",
    "        for name, param in model.named_parameters() :\n",
    "            param.requires_grad = False\n",
    "            if name.startswith('segmentation') :\n",
    "                param.requires_grad = True\n",
    "    \n",
    "    else :\n",
    "        print('Invalid stage input: only integers from 0 to 10 are valid')\n",
    "    \n",
    "    return model\n",
    "\n",
    "student = models.unet.Unet('resnet26', classes = 12, encoder_weights = None).cuda()\n",
    "stage = 0\n",
    "student = unfreeze(student, stage)\n",
    "for name, param in student.named_parameters() : \n",
    "    print(name, param.shape, param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (pyt)",
   "language": "python",
   "name": "pyt"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
